import torch
import numpy as np
import argparse
from utils.load_dataset import *
from utils.instantiate_model import *
from utils.str2bool import str2bool
from utils.measures import get_measures
from methods.msp import msp
from methods.msp_polar import msp as msp_polar
from methods.odin import ODIN, set_ODIN_hyperparams_for_indist
from methods.odin_polar import ODIN as ODIN_polar
from methods.odin_polar import set_ODIN_hyperparams_for_indist as set_ODIN_hyperparams_for_indist_polar
from methods.energy import energy_score
from methods.energy_polar import energy_score as energy_score_polar
import os
import multiprocessing
from torch.utils.tensorboard import SummaryWriter

# Row template for the result tables: a left-aligned, 12-character label
# followed by three metric groups. Each group holds two "value $ value" pairs
# printed with four decimals (callers in this file pass mean and std for ERM,
# then for VRM), giving 13 positional fields in total.
print_string = "".join(
    ["{: <12}\t\t|"] + ["\t{:.4f} $ {:.4f}\t\t{:.4f} $ {:.4f}  \t|"] * 3
)

def get_metric(in_dataset, ood_dataset, mode, suffix):
    """Load one trained model and score it on an in-distribution/OoD pair.

    Relies on the module-level ``args``, ``device`` and ``ood_technique_func``
    that the ``__main__`` section sets up before calling this function.

    Args:
        in_dataset: In-distribution dataset object. It is forwarded to
            ``instantiate_model``, to the ODIN hyper-parameter setup and to
            the detection function.
        ood_dataset: Out-of-distribution dataset object handed to the
            detection function.
        mode: Training-mode prefix of the checkpoint suffix (the caller in
            this file passes ``'vrm'`` or ``'mixup'``).
        suffix: Run identifier appended after ``mode`` as ``mode + "_" + suffix``.

    Returns:
        The result of ``get_measures(labels, pred)``; the caller in this file
        unpacks it as ``(auroc, aupr, fpr)``.
    """
    # Use the dataset that was passed in rather than a same-named module
    # global, so the function does not depend on how the caller names its
    # variables.
    net, net_name = instantiate_model(dataset=in_dataset,
                                      arch=args.arch,
                                      suffix=mode + "_" + suffix,
                                      load=True,
                                      torch_weights=False,
                                      device=device)
    net.eval()

    # The ODIN variants need their hyper-parameters set for this
    # in-distribution dataset before the detection function is called.
    method = args.method.lower()
    if method == 'odin':
        set_ODIN_hyperparams_for_indist(in_dataset, net, net_name, device)
    elif method == 'odin_polar':
        set_ODIN_hyperparams_for_indist_polar(in_dataset, net, net_name, device)

    labels, pred = ood_technique_func(in_dataset, ood_dataset, net, device)
    return get_measures(labels, pred)

def _format_metric_row(label, erm_mean, erm_std, vrm_mean, vrm_std):
    """Render one table row from per-metric ERM/VRM means and stds."""
    cells = []
    for idx in range(3):  # metric order: AUROC, AUPR, FPR95
        cells.extend([erm_mean[idx], erm_std[idx], vrm_mean[idx], vrm_std[idx]])
    return print_string.format(label, *cells)


def get_metric_for_dataset(dataset):
    """Evaluate one in-distribution dataset against every applicable OoD set.

    For each name in the module-level ``ood_datasets`` (skipping the dataset
    itself and any CIFAR-vs-CIFAR pairing) the OoD data is loaded with the
    in-distribution mean/std, every model suffix in ``args.suffixs`` is scored
    for the ``'vrm'`` and ``'mixup'`` modes, and one table row with the mean
    and standard deviation over suffixes is written to ``out_file``. Results
    of the ``'mixup'`` mode are reported in the columns labelled ERM.

    Relies on the module-level ``args``, ``device``, ``ood_datasets`` and
    ``out_file`` set up by the ``__main__`` section.

    Args:
        dataset: In-distribution dataset object; its ``name``, ``mean`` and
            ``std`` attributes are read here.

    Returns:
        Tuple ``(sum_vrm, sum_erm)`` of arrays with one row per evaluated OoD
        dataset and three columns (AUROC, AUPR, FPR95) holding the means over
        suffixes. Both have shape ``(0, 3)`` when no OoD dataset applied.
    """
    suffixs = args.suffixs.split(',')
    # Compare names case-insensitively on both sides.
    in_name = dataset.name.lower()

    print('------------------------------------------------------------------------------', file=out_file)
    print('\t'*8 + dataset.name, file=out_file)
    print('------------------------------------------------------------------------------', file=out_file)
    print("OoD Dataset  \t\t|\t\t\t\t\tAUROC\t\t\t\t\t|\t\t\t\t\tAUPR\t\t\t\t\t|\t\t\t\t\tFPR95\t\t\t\t\t|", file=out_file)
    print("             \t\t|" + "\t\t\tERM\t\t\t\t\tVRM\t\t\t|" * 3, file=out_file)

    vrm_rows = []
    erm_rows = []
    for ood_dataset_name in ood_datasets:
        ood_name = ood_dataset_name.lower()
        if ood_name == in_name:
            continue

        # CIFAR variants are never used as OoD data for one another.
        if 'cifar' in in_name and 'cifar' in ood_name:
            continue

        ood_dataset = load_dataset(dataset=ood_dataset_name,
                                   train_batch_size=args.train_batch_size,
                                   test_batch_size=args.test_batch_size,
                                   val_split=args.val_split,
                                   augment=args.augment,
                                   padding_crop=args.padding_crop,
                                   shuffle=args.shuffle,
                                   random_seed=args.random_seed,
                                   device=device,
                                   mean=dataset.mean,
                                   std=dataset.std)

        # One [auroc, aupr, fpr] entry per suffix for each training mode.
        per_mode = {'vrm': [], 'mixup': []}
        for algo in ('vrm', 'mixup'):
            for suffix in suffixs:
                auroc, aupr, fpr = get_metric(dataset,
                                              ood_dataset,
                                              mode=algo,
                                              suffix=str(suffix))
                per_mode[algo].append([auroc, aupr, fpr])

        vrm = np.array(per_mode['vrm'])
        erm = np.array(per_mode['mixup'])
        vrm_values = vrm.mean(axis=0)
        erm_values = erm.mean(axis=0)

        vrm_rows.append(vrm_values)
        erm_rows.append(erm_values)

        print(_format_metric_row(ood_dataset.name,
                                 erm_values,
                                 erm.std(axis=0),
                                 vrm_values,
                                 vrm.std(axis=0)),
              file=out_file)

    if not vrm_rows:
        # Nothing was evaluated, so there is no average to report; return
        # empty (0, 3) arrays instead of failing on an empty reduction.
        print("", file=out_file)
        return np.empty((0, 3)), np.empty((0, 3))

    sum_vrm = np.array(vrm_rows)
    sum_erm = np.array(erm_rows)

    # The "Average" row aggregates the per-OoD-dataset means.
    print(_format_metric_row("Average",
                             sum_erm.mean(axis=0),
                             sum_erm.std(axis=0),
                             sum_vrm.mean(axis=0),
                             sum_vrm.std(axis=0)),
          file=out_file)
    print("", file=out_file)

    return sum_vrm, sum_erm


# Script entry point: parse the CLI, pick the OoD detection function, evaluate
# every in-distribution dataset and write the result tables to --outfile.
# The names args, device, ood_datasets, ood_technique_func, out_file and
# dataset are deliberately module-level: the functions above read them.
if __name__ == "__main__":
    if os.name == 'nt':
        # On Windows calling this function is necessary for multiprocessing
        multiprocessing.freeze_support()

    parser = argparse.ArgumentParser(description='Train', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--parallel',               default=False,          type=str2bool,  help='Device in  parallel')

    # Dataloader args
    parser.add_argument('--train_batch_size',       default=1256,           type=int,       help='Train batch size')
    parser.add_argument('--test_batch_size',        default=1256,           type=int,       help='Test batch size')
    parser.add_argument('--val_split',              default=0.1,            type=float,     help='Fraction of training dataset split as validation')
    parser.add_argument('--augment',                default=True,           type=str2bool,  help='Random horizontal flip and random crop')
    parser.add_argument('--padding_crop',           default=4,              type=int,       help='Padding for random crop')
    parser.add_argument('--shuffle',                default=True,           type=str2bool,  help='Shuffle the training dataset')
    parser.add_argument('--random_seed',            default=0,              type=int,       help='Initialising the seed for reproducibility')
    parser.add_argument('--arch',                   default='resnet18',     type=str,       help='Network architecture')
    parser.add_argument('--suffixs',                default='1,2,3,4,5',    type=str,       help='Model suffixs')
    parser.add_argument('--method',                 default='msp_polar',         type=str,       help='Comparison with other OoD methods')
    parser.add_argument('--outfile',                default='comp.out',     type=str,       help='Name of the output file to store results from comparison with other OoD methods')

    args = parser.parse_args()
    print(args)

    # Resolve the detection function before anything is created on disk, so an
    # unsupported --method fails without truncating an existing results file.
    ood_methods = {
        'msp': msp,
        'odin': ODIN,
        'energy_score': energy_score,
        'msp_polar': msp_polar,
        'odin_polar': ODIN_polar,
        'energy_score_polar': energy_score_polar,
    }
    method_key = args.method.lower()
    if method_key not in ood_methods:
        raise ValueError("Unsupported OoD detection method: {!r} (choose from {})".format(
            args.method, ', '.join(sorted(ood_methods))))
    ood_technique_func = ood_methods[method_key]

    # Setup right device to run on
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): this writer is never written to in this file; it is kept
    # so the script's side effects stay unchanged.
    writer = SummaryWriter(comment="Grad mask")

    in_datasets = ['cifar10', 'cifar100', 'svhn', 'tinyimagenet']
    ood_datasets = ['g-noise', 'u-noise', 'svhn', 'cifar100', 'textures', 'lsun', 'tinyimagenet', 'places365']

    # Per in-distribution dataset: array with one row per OoD dataset and the
    # columns AUROC, AUPR, FPR95.
    dataset_vrm = []
    dataset_erm = []

    # The context manager closes the results file even if an evaluation fails.
    with open(args.outfile, "w") as out_file:
        print("Saving result to file {}".format(args.outfile))
        print('\n')

        for in_dataset in in_datasets:
            dataset = load_dataset(dataset=in_dataset,
                                   train_batch_size=args.train_batch_size,
                                   test_batch_size=args.test_batch_size,
                                   val_split=args.val_split,
                                   augment=args.augment,
                                   padding_crop=args.padding_crop,
                                   shuffle=args.shuffle,
                                   random_seed=args.random_seed,
                                   device=device)

            vrm, erm = get_metric_for_dataset(dataset)
            dataset_vrm.append(vrm)
            dataset_erm.append(erm)

            print(vrm)
            print(erm)
            print("{} Done".format(in_dataset))

        # Pool every (in-distribution, OoD) row before reducing. This gives
        # the same mean/std as reducing a stacked 3-D array over its first
        # two axes, but also works when the in-distribution datasets were
        # evaluated against different numbers of OoD datasets.
        all_vrm = np.concatenate(dataset_vrm, axis=0)
        all_erm = np.concatenate(dataset_erm, axis=0)

        dataset_vrm_std = all_vrm.std(axis=0)
        dataset_erm_std = all_erm.std(axis=0)
        dataset_vrm_mean = all_vrm.mean(axis=0)
        dataset_erm_mean = all_erm.mean(axis=0)

        # Field order per metric: ERM mean, ERM std, VRM mean, VRM std.
        cells = []
        for idx in range(3):
            cells.extend([dataset_erm_mean[idx],
                          dataset_erm_std[idx],
                          dataset_vrm_mean[idx],
                          dataset_vrm_std[idx]])

        print('------------------------------------------------------------------------------', file=out_file)
        print(print_string.format("Average", *cells), file=out_file)
        print("", file=out_file)
